Support chip count change when loading grain checkpoint #2537
Conversation
I see a merge conflict. It's a known issue that prevents triggering the Gemini review.
RissyRan left a comment:
Thanks Aireen! Do you think we could get this reviewed by the Grain team as well?
📋 Review Summary
This pull request introduces a sophisticated feature to handle changes in chip count when resuming training with Grain checkpoints. The implementation is thorough, touching upon checkpointing, data loading, and configuration to support both scaling up and scaling down scenarios. The logic is well-structured, especially in the new GrainCheckpointHandler and the restore strategies in _restore_grain_iterator.
🔍 General Feedback
- The addition of documentation in data_input_grain.md is very helpful for understanding this complex feature.
- The code is well-organized, and the separation of concerns for handling different scaling scenarios is clear.
- The changes to wait for the checkpoint manager to finish at the end of training in train.py and other trainer scripts are a good reliability improvement.
Overall, this is a solid contribution that adds significant flexibility to the training pipeline. I have one minor suggestion for improving clarity.
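The scaling behavior described in the summary can be sketched roughly as follows. This is an illustrative sketch only: the function name, the placeholder marker, and the even-split assumption for scaling down are assumptions, not the actual MaxText/Grain API (the real logic lives in GrainCheckpointHandler and _restore_grain_iterator).

```python
def restore_grain_iterator(saved_host_count, current_host_count, host_index, saved_states):
  """Pick the iterator state(s) for `host_index` after a host-count change.

  Hypothetical sketch: assumes saved per-host iterator states are a flat
  list and that scaling down divides hosts evenly.
  """
  if current_host_count == saved_host_count:
    # No change in chip/host count: plain 1:1 restore.
    return saved_states[host_index]
  if current_host_count > saved_host_count:
    # Scaling up: hosts beyond the saved range get a placeholder iterator
    # whose fake data must never be fed to training.
    if host_index < saved_host_count:
      return saved_states[host_index]
    return "PLACEHOLDER"
  # Scaling down: each surviving host takes over several saved states.
  ratio = saved_host_count // current_host_count
  return [saved_states[host_index * ratio + i] for i in range(ratio)]
```

For example, restoring a 4-host checkpoint on 2 hosts hands host 1 the saved states of former hosts 2 and 3, while restoring a 2-host checkpoint on 4 hosts gives hosts 2 and 3 placeholders.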
Description
See the paragraph added to data_input_grain.md in this PR for a description of this behavior.
Also in b/452377649
Tests
Tested with expansion_factor_real_data=2 and max_checkify=true (to make sure the PlaceHolderIterator's fake data are not used for training); restored the checkpoint from step=3 and ran until step=9, saving checkpoints at step=6 and step=9 (log).
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.